
Add Python distributed L4 to L3 dispatch #711

Open
PKUZHOU wants to merge 6 commits into hw-native-sys:main from PKUZHOU:feat/l4-l3-distributed-dispatch

Conversation

@PKUZHOU
Contributor

@PKUZHOU PKUZHOU commented May 6, 2026

Summary

  • add Python-first gRPC/protobuf distributed dispatch package for L4 -> remote L3
  • integrate Worker.add_remote_worker() through a local PROCESS mailbox shim without C++/nanobind changes
  • add callable catalog, L3 daemon backend process, heartbeat, tensor-pool surface, examples, and docs

Tests

  • python -m pytest tests/ut/py/test_distributed tests/ut/py/test_worker/test_l4_recursive.py -q
  • python -m compileall -q python/simpler/distributed tests/ut/py/test_distributed examples/distributed/l4_l3_remote
  • git diff --check

Notes

  • current e2e remote dispatch covers scalar TaskArgs and callable execution
  • full remote tensor materialization/output write-back remains future work


@gemini-code-assist gemini-code-assist Bot left a comment


Code Review

This pull request implements a distributed L4 to L3 dispatch system using gRPC and protobuf, enabling cross-host task execution. It introduces a callable catalog for remote function registration, a long-running L3 daemon with a fork-safe backend process, and a mailbox shim thread to integrate remote workers into the existing C++ scheduler. Feedback highlights critical issues regarding the transmission of raw memory pointers across host boundaries, which would lead to segmentation faults. Other recommendations include removing redundant logic in the catalog registration, ensuring consistent use of cloudpickle for deserialization, and improving error handling for unexpected backend process terminations.

tag = args.tag(i)
tensors.append(
    dispatch_pb2.ContinuousTensorRef(
        data=int(tensor.data),


high

Sending raw memory pointers (tensor.data) across host boundaries in a distributed system is incorrect. These addresses are local to the L4 process and will be invalid on the remote L3 node, leading to segmentation faults if accessed. A position-independent mechanism, such as handles or offsets into a shared tensor pool, should be used instead.

References
  1. To ensure shared memory is position-independent for future cross-process/cross-address-space communication, avoid storing absolute pointers (to stack or heap) within shared memory structures. Use relative offsets or process-local handles instead.
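As an illustration of the handle-based alternative the reviewer suggests (all names below are hypothetical sketches, not part of this PR), a tensor pool can hand out opaque `(pool_id, offset, nbytes)` references that each process resolves against its own mapping of the pool, so no absolute address ever crosses the wire:

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class TensorRef:
    """Position-independent reference: pool id + byte offset, never a raw pointer."""
    pool_id: int
    offset: int
    nbytes: int


class TensorPool:
    """Toy in-process backing store; a real pool would wrap shared memory or RDMA-registered buffers."""

    def __init__(self, pool_id: int, size: int):
        self.pool_id = pool_id
        self._buf = bytearray(size)
        self._cursor = 0

    def alloc(self, data: bytes) -> TensorRef:
        off = self._cursor
        self._buf[off:off + len(data)] = data
        self._cursor += len(data)
        return TensorRef(self.pool_id, off, len(data))

    def resolve(self, ref: TensorRef) -> memoryview:
        if ref.pool_id != self.pool_id:
            raise ValueError("ref belongs to a different pool")
        return memoryview(self._buf)[ref.offset:ref.offset + ref.nbytes]


pool = TensorPool(pool_id=1, size=1024)
ref = pool.alloc(b"\x01\x02\x03\x04")
# Only (pool_id, offset, nbytes) would cross the wire; the receiver
# resolves the ref against its own mapping of the same pool.
assert bytes(pool.resolve(ref)) == b"\x01\x02\x03\x04"
```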

shape = tuple(int(x) for x in ref.shape)
dtype = DataType(int(ref.dtype))
tag = TensorArgType(int(ref.tag))
args.add_tensor(ContinuousTensor.make(int(ref.data), shape, dtype), tag)


high

Reconstructing a ContinuousTensor from a raw pointer received over the network is dangerous. In a distributed environment, ref.data is a pointer from a different address space (and likely a different host), making it invalid for local use.

References
  1. To ensure shared memory is position-independent for future cross-process/cross-address-space communication, avoid storing absolute pointers (to stack or heap) within shared memory structures. Use relative offsets or process-local handles instead.

Comment on lines +40 to +43
if callable_id is None:
    self._next_id = max(self._next_id, cid + 1)
else:
    self._next_id = max(self._next_id, cid + 1)


medium

The self._next_id update is redundant here because self.install_from_payload(cid, version, payload) (called on line 39) already performs the exact same max(self._next_id, cid + 1) update on line 67. Additionally, the if/else branches are identical.

References
  1. Reuse existing helper functions or methods instead of duplicating their logic. This improves consistency, maintainability, and reduces the chance of introducing bugs.


def _loads_with_allowlist(payload: bytes, allowed_modules: Optional[Tuple[str, ...]]) -> Callable:
    if allowed_modules is None:
        return pickle.loads(payload)


medium

Since cloudpickle (aliased as _pickle_impl) is used for serialization in register, it should also be used for deserialization here to ensure compatibility, especially for lambdas and closures which standard pickle cannot handle.

Suggested change
return pickle.loads(payload)
return _pickle_impl.loads(payload)

Comment on lines +104 to +106
with self._backend_lock:
    self._backend_conn.send(msg)
    ok, payload = self._backend_conn.recv()


medium

self._backend_conn.recv() will raise an EOFError if the backend process terminates unexpectedly (e.g., due to a crash). This should be handled to provide a more descriptive error message to the RPC client rather than letting the gRPC handler thread fail with an unhandled exception.

Suggested change
with self._backend_lock:
    self._backend_conn.send(msg)
    ok, payload = self._backend_conn.recv()
try:
    with self._backend_lock:
        self._backend_conn.send(msg)
        ok, payload = self._backend_conn.recv()
except EOFError:
    raise RuntimeError("L3 daemon backend process terminated unexpectedly") from None

Comment on lines +130 to +131
def sleep_poll_interval() -> None:
    time.sleep(0.0005)


medium

The function sleep_poll_interval appears to be unused in the current implementation.

References
  1. Remove unused exported variables to prevent them from causing errors if their values become incorrect in different environments.

@jvjhfhg
Collaborator

jvjhfhg commented May 7, 2026

examples/distributed/l4_l3_remote/l4_master.py seems to imply that remote execution can mutate L4-local Python state:

counter = Counter()

def l3_sub(task_args):
    counter.add(task_args.scalar(0))

...
print(f"remote counter={counter.value}")
return 0 if counter.value == 7 else 1

But l3_sub is pushed to the L3 daemon through the callable catalog, so the captured counter is serialized with the closure. The remote backend/sub-worker mutates its own deserialized/forked copy, not the original counter in the L4 process.

Meanwhile, the current dispatch response only reports success/failure:

inner.run(orch_fn, args, cfg)
return dispatch_pb2.DispatchResp(task_id=req.task_id, error_code=0), inner

and RemoteWorkerProxy.dispatch() only checks error_code; it does not read or materialize DispatchResp.output_tensors.

The tests (test_l4_l3_remote.py) avoid this by using externally visible state such as a file/shared-memory counter, which verifies that remote execution happened, but it does not demonstrate a real distributed result-return path. For a cross-host example, this feels misleading.

The example should either avoid expecting L4-local closure state to change, or explicitly use/document an external side effect until DispatchResp.output_tensors or another result-return mechanism is implemented.
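The copy semantics described above are easy to reproduce with the standard library (cloudpickle behaves the same way for objects captured by a shipped closure): the deserialized side mutates its own copy, and the original is untouched. This stand-in pickles the captured `Counter` directly to show the effect:

```python
import pickle


class Counter:
    def __init__(self):
        self.value = 0

    def add(self, n):
        self.value += n


counter = Counter()

# Simulate shipping the captured object to a remote process: serialization
# snapshots its state, and the receiver gets an independent copy.
remote_counter = pickle.loads(pickle.dumps(counter))
remote_counter.add(7)

assert counter.value == 0         # the L4-local original never changes
assert remote_counter.value == 7  # only the "remote" deserialized copy did
```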

Contributor Author

@PKUZHOU PKUZHOU left a comment


Retracted: this review was posted from the wrong GitHub identity and will be reposted from the intended project-local account with a self-contained summary.

Contributor

@uv-xiao uv-xiao left a comment


Implementation review summary for the L4-L3 distributed dispatch PR. I focused on behavioral and semantic issues rather than CI/style.

Main concerns:

  1. Callable catalog versioning is internally inconsistent: PullCallable(version=0) resolves latest payload bytes but returns version 0, which can fail install-time version validation.
  2. The documented L4/L3 example constructs TaskArgs in an invalid scalar-before-tensor order.
  3. The dispatch proto exposes both legacy address-like tensor_args and remote data-plane tensor_refs; the contract should be narrowed or validated so raw L4 addresses are not accidentally treated as remote execution pointers.
  4. Tensor input staging is currently selected by an internal byte threshold. The distributed L4 programming model should prefer an explicit handle/remote-storage path chosen by the program, with inline bytes only as an explicit small-message/test path.
  5. OUTPUT and OUTPUT_EXISTING staging reads and transfers old local buffer bytes even though output-only tensors do not semantically consume prior contents.
  6. INOUT currently has copy-in/copy-out semantics and is excluded from the RXE local-output fast path, so it should not be described as shared or in-place remote memory.
  7. Tensor-ref dispatch creates an ephemeral backend Worker(level=3) per request, which differs materially from persistent local L3 worker reuse and needs rationale/benchmarking or a plan for persistent child-visible tensor storage.
  8. L3 daemon dispatch goes through serialized blocking foreground-to-backend IPC after the gRPC call; the process split and overhead need to be justified or simplified.
  9. Heartbeat checks only foreground gRPC liveness, not backend/runtime/device/TensorPool readiness.
  10. Remote callables use cloudpickle/pickle, whose semantics differ from local fork/COW callable inheritance; the remote callable contract and trusted-cluster assumption should be explicit.

context.abort(grpc.StatusCode.NOT_FOUND, str(e))
return dispatch_pb2.CallablePayload(
    callable_id=request.callable_id,
    version=request.version,
Contributor


version=0 is treated as "latest" by export_payload(), but this response returns the literal request version. That can return valid payload bytes with version=0, and install_from_payload() validates version == hash(payload), so a caller installing the response can fail. Please return the resolved payload version, or remove latest-version semantics from this RPC.
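A hedged sketch of the fix (catalog names are hypothetical stand-ins): resolve `version=0` to the stored latest version before building the response, so the returned version always equals the payload hash that install-time validation recomputes.

```python
import hashlib


def payload_version(payload: bytes) -> int:
    # Stand-in for the catalog's hash-derived versioning scheme.
    return int.from_bytes(hashlib.sha256(payload).digest()[:8], "big")


class CatalogStub:
    def __init__(self):
        self._entries = {}  # callable_id -> (version, payload)

    def put(self, cid: int, payload: bytes) -> int:
        version = payload_version(payload)
        self._entries[cid] = (version, payload)
        return version

    def export_payload(self, cid: int, version: int):
        stored_version, payload = self._entries[cid]
        # version == 0 means "latest": return the *resolved* version,
        # never the literal request value.
        if version == 0:
            return stored_version, payload
        if version != stored_version:
            raise KeyError(f"callable {cid} version {version} not found")
        return stored_version, payload


catalog = CatalogStub()
v = catalog.put(7, b"payload-bytes")
resolved, payload = catalog.export_payload(7, 0)
# The caller's install-time check (version == hash(payload)) now passes.
assert resolved == v == payload_version(payload)
```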

def l4_orch(orch, task_args, config):
    for value in (2, 5):
        sub_args = TaskArgs()
        sub_args.add_scalar(value)
Contributor


This example calls add_scalar() before add_tensor(), but TaskArgs rejects adding tensors after scalars. Running the documented example fails with RuntimeError: TaskArgs: cannot add tensor after scalar. Please add tensors before scalars, or change the TaskArgs contract/tests if interleaving is intended.
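The ordering contract can be illustrated with a minimal stand-in (this is not the real TaskArgs, just a toy model of its documented tensors-before-scalars rule):

```python
class TaskArgsStub:
    """Toy model of the tensors-before-scalars ordering rule."""

    def __init__(self):
        self._tensors = []
        self._scalars = []

    def add_tensor(self, tensor, tag=None):
        if self._scalars:
            raise RuntimeError("TaskArgs: cannot add tensor after scalar")
        self._tensors.append((tensor, tag))

    def add_scalar(self, value):
        self._scalars.append(int(value))


ok = TaskArgsStub()
ok.add_tensor(object())  # tensors first...
ok.add_scalar(5)         # ...then scalars: accepted

bad = TaskArgsStub()
bad.add_scalar(5)
try:
    bad.add_tensor(object())  # scalar-before-tensor, as in the example
except RuntimeError as e:
    print(e)  # TaskArgs: cannot add tensor after scalar
```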

uint64 callable_version = 3;
bytes config_blob = 4;
repeated uint64 scalar_args = 5;
repeated ContinuousTensorRef tensor_args = 6;
Contributor


tensor_args carries ContinuousTensorRef.data, an address-like value from the sender process. That is not a valid execution pointer on remote L3. Since tensor_refs is the real remote data-plane schema, please clarify whether tensor_args remains supported; if not, reject non-empty tensor_args in the daemon or remove it from the active dispatch path.

local_output_regions.append(region)
continue
data = ctypes.string_at(int(tensor.data), nbytes) if nbytes else b""
if nbytes <= self._tensor_inline_threshold:
Contributor


This makes inline-vs-TensorPool staging an implicit byte-threshold decision. I think the intended L4 programming model should be explicit: the L4 program allocates/registers remote tensor storage, gets a TensorHandle/remote ref, and passes that as the tensor argument. Inline bytes can remain as an explicit small-message/test path, but the normal distributed tensor path should not silently switch based only on size.
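One way to make staging explicit (a hypothetical API sketch, not this PR's code): the program chooses the transport per tensor, and the byte threshold only applies when the caller explicitly opted into automatic selection.

```python
from enum import Enum


class Staging(Enum):
    INLINE = "inline"     # explicit small-message/test path: bytes ride in the request
    POOL_HANDLE = "pool"  # pre-registered remote tensor storage, passed by handle
    AUTO = "auto"         # explicit opt-in to threshold-based selection


INLINE_THRESHOLD = 4096  # assumed threshold, honored only under AUTO


def choose_staging(requested: Staging, nbytes: int) -> Staging:
    if requested is Staging.AUTO:
        return Staging.INLINE if nbytes <= INLINE_THRESHOLD else Staging.POOL_HANDLE
    return requested  # the program's explicit choice always wins


# The explicit choice is never silently overridden by size:
assert choose_staging(Staging.POOL_HANDLE, 16) is Staging.POOL_HANDLE
# AUTO keeps the current threshold behavior, but as an opt-in:
assert choose_staging(Staging.AUTO, 16) is Staging.INLINE
assert choose_staging(Staging.AUTO, 1 << 20) is Staging.POOL_HANDLE
```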

refs.append(ref)
local_output_regions.append(region)
continue
data = ctypes.string_at(int(tensor.data), nbytes) if nbytes else b""
Contributor


This reads bytes for every tensor, including OUTPUT/OUTPUT_EXISTING unless the large RXE output path is selected. Old output-buffer contents are not semantic inputs, so small output tensors and fallback output paths send irrelevant bytes to L3. Please separate input staging from output allocation/writeback.

def _should_stage_local_output(self, tag, nbytes: int) -> bool:  # noqa: ANN001
    return (
        self._tensor_transport in {"rxe", "auto"}
        and getattr(tag, "name", "") in {"OUTPUT", "OUTPUT_EXISTING"}
Contributor


The local-output RXE fast path excludes INOUT, so large INOUT is copy-in/copy-out rather than bidirectional RXE. That is acceptable as an MVP if documented, but it should not be described as shared or in-place remote memory.

else:
    args = decode_task_args(req.tensor_args, req.scalar_args)
if req.tensor_refs:
    run_inner = worker_factory()
Contributor


Tensor-ref dispatch creates a fresh backend Worker(level=3) per request, initializes it, runs it, and closes it. This is likely needed so newly materialized mmap buffers exist before L3 children fork, but it is a major lifecycle/performance difference from a persistent L3 worker. Please document this tradeoff and the plan for persistent worker + child-visible tensor storage.

def _backend_call(self, msg):
    if self._backend_conn is None:
        raise RuntimeError("L3 daemon backend is not running")
    with self._backend_lock:
Contributor


Every foreground RPC is serialized through _backend_lock and a blocking Pipe send/recv before backend execution. If this process split is required for fork-safety with gRPC threads, please document it and provide overhead numbers; otherwise consider a direct backend RPC/event loop.

)

def Heartbeat(self, request, context):  # noqa: N802, ANN001
    return dispatch_pb2.Health(ok=True, message="ok")
Contributor


This heartbeat only proves foreground gRPC liveness. It does not check backend process health, worker initialization, runtime/device readiness, TensorPool capacity, or selected transport. Please either rename/document it as liveness only, or add a deeper readiness RPC.

self._allowed_modules = allowed_modules

def register(self, fn: Callable, callable_id: Optional[int] = None) -> tuple[int, int]:
    payload = _pickle_impl.dumps(fn)
Contributor


Remote callables are serialized with cloudpickle/pickle, which is not equivalent to local fork/COW callable inheritance. Captured mutable state, raw pointers, file descriptors, sockets, locks, device contexts, imports, and side effects expected to be visible at L4 can behave differently. Please document the remote callable contract and the trusted-cluster assumption.

Contributor

@uv-xiao uv-xiao left a comment


Second review round: callable semantics and the L4/L3 remote example.

The main issue is that the public API is ABI-uniform but not semantically uniform. submit_next_level(callable, ...) accepts a uint64, but the semantic meaning changes by level/target: L3-to-L2 expects a chip callable handle, L4-to-L3 expects a Python orchestration callable id, and remote L3 treats the value as a catalog id. Worker.register(fn) also stores both SubWorker callables (fn(task_args)) and orchestration callables (fn(orch, task_args, config)) in one untyped namespace. The current example makes this hard to understand because w4.register(l3_sub) registers a callable intended for a remote L3 SubWorker, while w4.register(l3_orch) registers the remote L3 orchestration function.

I think the minimum direction should be typed callable registration/handles, for example register_sub(...), register_orch(...), and a distinct chip callable handle. The internal slot can still carry a compact integer, but the public API should prevent passing a subworker callable id where a next-level orchestration callable is expected.
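The typed-handle direction can be sketched as follows (all names are hypothetical; the internal wire slot can stay a bare uint64 while the public API gains distinct handle types):

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class SubCallableId:
    """Handle for fn(task_args)-shaped SubWorker callables."""
    cid: int


@dataclass(frozen=True)
class OrchCallableId:
    """Handle for fn(orch, task_args, config)-shaped next-level orchestration callables."""
    cid: int


class WorkerStub:
    def __init__(self):
        self._registry = {}  # cid -> (kind, fn)

    def register_sub(self, fn) -> SubCallableId:
        cid = len(self._registry)
        self._registry[cid] = ("sub", fn)
        return SubCallableId(cid)

    def register_orch(self, fn) -> OrchCallableId:
        cid = len(self._registry)
        self._registry[cid] = ("orch", fn)
        return OrchCallableId(cid)

    def submit_next_level(self, handle, *args) -> int:
        if not isinstance(handle, OrchCallableId):
            raise TypeError("submit_next_level() requires an orchestration handle")
        return handle.cid  # the mailbox/proto representation remains a compact integer


w = WorkerStub()
sub = w.register_sub(lambda task_args: None)
orch = w.register_orch(lambda o, t, c: None)
w.submit_next_level(orch)  # accepted
try:
    w.submit_next_level(sub)  # wrong callable kind, rejected at the API boundary
except TypeError as e:
    print(e)
```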

endpoints = [item.strip() for item in args.remotes.split(",") if item.strip()]

def l3_sub(task_args):
    output = task_args.tensor(1)
Contributor


task_args contains one tensor in this example, added below at line 36, so this should be task_args.tensor(0). As written, the subworker tries to read tensor index 1 and the example cannot demonstrate the intended remote L3 writeback behavior.

sub_cid = w4.register(l3_sub)

def l3_orch(orch, task_args, config):
    orch.submit_sub(sub_cid, task_args)
Contributor


This makes the example hard to interpret. submit_sub() targets an L3 SubWorker, not the next-level remote L3 worker itself. The actual path is L4 submit_next_level(l3_cid) -> remote L3 runs l3_orch -> remote L3 submit_sub(sub_cid) -> SubWorker mutates the tensor. If this example is meant to explain L4-to-remote-L3 dispatch, please either do the mutation directly in l3_orch, or make the example explicitly about "L4 dispatch to remote L3, then remote L3 dispatch to its SubWorker".

current.value += int(task_args.scalar(0))

w4 = Worker(level=4, num_sub_workers=0)
sub_cid = w4.register(l3_sub)
Contributor


This is semantically confusing: l3_sub is not a W4 SubWorker callable because w4 has num_sub_workers=0; it is intended to be shipped through the L4 catalog and later used by the remote L3 daemon as an L3 SubWorker callable. That works only because Worker.register() is an untyped global callable-id table. Please make the callable kind explicit, for example register_sub(...) vs register_orch(...), or otherwise document and validate that this id is meant for the remote child's SUB namespace.

def l4_orch(orch, task_args, config):
    for value in (2, 5):
        sub_args = TaskArgs()
        sub_args.add_scalar(value)
Contributor


This adds a scalar before adding a tensor, but TaskArgs requires tensors before scalars. The example should add the output tensor first and then the scalar. This concrete bug also obscures the higher-level callable-semantics issue because the example can fail before reaching the remote dispatch path.

Comment thread python/simpler/worker.py
raise RuntimeError("Worker.register() must be called before init()")
cid = len(self._callable_registry)
self._callable_registry[cid] = fn
if self._distributed_catalog is not None:
Contributor


register() stores every Python callable in one _callable_registry, but the callable contracts are different depending on how the id is later used. submit_sub(cid, ...) expects fn(task_args), while L4 submit_next_level(cid, ...) expects a child orchestration callable shaped like fn(orch, task_args, config). The registry should not be untyped here. Please split this into explicit APIs such as register_sub() and register_orch(), or return typed handles so the wrong callable kind cannot be submitted to the wrong worker type.

Comment thread python/simpler/worker.py
from .distributed.catalog import Catalog # noqa: PLC0415

self._distributed_catalog = Catalog()
for cid, fn in self._callable_registry.items():
Contributor


The distributed catalog mirrors every entry from _callable_registry, but it does not preserve whether an entry is a SubWorker callable or a next-level orchestration callable. That is why the L4 example can register both l3_sub and l3_orch on w4 and rely on later submit paths to decide their meaning. For remote dispatch this should be explicit in the catalog payload or handle type; otherwise a wrong id can cross the network and fail only later as a signature/runtime mismatch.

Comment thread python/simpler/worker.py
try:
    args = _read_args_from_mailbox(buf)
    cfg = _read_config_from_mailbox(buf)
    proxy.dispatch(int(cid), args, cfg)
Contributor


This forwards the mailbox callable field as a remote catalog id. That is the core semantic mismatch: the same submit_next_level(callable, ...) slot can mean a chip callable handle for L3-to-L2, but a Python orchestration/catalog id for L4-to-L3. The public API should expose a typed next-level orchestration handle here, even if the mailbox representation remains a uint64 internally.
